from typing import Any
import cv2
import numpy as np
from yolov7 import YOLOv7
import gradio as gr
from PIL import Image

class Inference:
    """Callable wrapper that owns one ``YOLOv7`` detector.

    The detector is created once at construction time and stored on
    ``self.model``; calling the instance forwards a frame and the two
    thresholds to it.
    """

    def __init__(self, model_path, labels_path, engine_path):
        """Create the detector from the three paths and keep it on ``self.model``."""
        self.model = self.setup_models(model_path, labels_path, engine_path)

    def setup_models(self, model_path, labels_path, engine_path):
        """Build and return a ``YOLOv7`` object.

        The three arguments are handed to ``YOLOv7`` positionally, in the
        order ``model_path``, ``labels_path``, ``engine_path``.
        """
        return YOLOv7(model_path, labels_path, engine_path)

    def __call__(self, frame: np.ndarray, conf_threshold: float, nms_threshold: float, *args: Any, **kwds: Any) -> Any:
        """Run the detector on ``frame`` and return ``(boxes, scores, class_ids)``.

        Any extra positional or keyword arguments are accepted but ignored.
        """
        detections = self.model(frame, conf_threshold, nms_threshold)
        # Unpack explicitly so the result is always a plain 3-tuple.
        boxes, scores, class_ids = detections
        return boxes, scores, class_ids

# Both detectors are built once at import time and reused by every request.
# They share the same labels file and differ only in model/engine files.

# First detector; its annotated output is the first image returned by run().
infer1 = Inference(
    model_path="models/firesmoke.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke.trt",
)

# Second detector; its annotated output is the second image returned by run().
infer2 = Inference(
    model_path="models/firesmoke-henry.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke-henry.trt",
)

def _draw_detections(image, boxes, scores, class_ids):
    """Return a copy of ``image`` with every detection drawn on it.

    For each ``(box, score, class_id)`` triple a rectangle outline is drawn
    from ``(box[0], box[1])`` to ``(box[2], box[3])``, plus a filled tag above
    the top-left corner containing ``"<class_id>:<score with 2 decimals>"``.
    ``image`` itself is left untouched.
    """
    box_color = (0, 0, 255)  # red when the image is in BGR channel order
    text_color = (255, 255, 255)
    canvas = image.copy()
    # zip() over empty results simply draws nothing, so no length check is needed.
    for box, score, class_id in zip(boxes, scores, class_ids):
        x1, y1, x2, y2 = (int(coord) for coord in box[:4])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), box_color, 2)
        # Filled 100x20 strip above the box acts as a background for the text.
        cv2.rectangle(canvas, (x1, y1 - 20), (x1 + 100, y1), box_color, -1)
        cv2.putText(
            canvas,
            f"{class_id}:{score:.2f}",
            (x1, y1),
            cv2.FONT_HERSHEY_COMPLEX,
            0.5,
            text_color,
            1,
        )
    return canvas


def run(content_img, conf_threshold, nms_threshold):
    """Run both detectors on one image and return the two annotated results.

    Args:
        content_img: Input image in RGB channel order (anything ``np.array``
            can turn into an image array), or ``None`` when nothing was
            supplied.
        conf_threshold: Confidence threshold forwarded to both detectors.
        nms_threshold: NMS threshold forwarded to both detectors.

    Returns:
        A 2-tuple of PIL images ``(result_of_infer1, result_of_infer2)``, or
        ``(None, None)`` when ``content_img`` is ``None``.
    """
    # Without this guard a missing image reaches cv2.cvtColor and fails with
    # an opaque OpenCV error; returning empty outputs is clearer for the UI.
    if content_img is None:
        return None, None

    # The drawing code below works in OpenCV's BGR order.
    frame = cv2.cvtColor(np.array(content_img), cv2.COLOR_RGB2BGR)

    # Run both models first, then draw, so inference always sees a clean frame.
    all_detections = (
        infer1(frame, conf_threshold, nms_threshold),
        infer2(frame, conf_threshold, nms_threshold),
    )

    annotated = []
    for boxes, scores, class_ids in all_detections:
        drawn = _draw_detections(frame, boxes, scores, class_ids)
        # Convert back to RGB before handing the pixels to PIL.
        annotated.append(Image.fromarray(cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB)))

    return annotated[0], annotated[1]

if __name__ == '__main__':
    # Web UI: one input image plus two threshold sliders, and two output
    # images showing the annotated result of each detector side by side.
    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Input Image'),
            # The component-style gr.Slider takes its initial position via
            # `value`; `default` is not one of its parameters.
            gr.Slider(minimum=0.05, maximum=1, step=0.05, label="Confidence Threshold", value=0.2),
            gr.Slider(minimum=0.05, maximum=1, step=0.05, label="NMS Threshold", value=0.5),
        ],
        outputs=[
            gr.Image(
                type="pil",
                label="Finetuned"
            ),
            gr.Image(
                type="pil",
                label="Finetuned + New Data"
            ),
        ],
        # Each example row is [image path, confidence threshold, NMS threshold],
        # matching the order of `inputs`.
        examples=[
            ['examples/fire1.jpg', 0.2, 0.5],
            ['examples/fire2.jpg', 0.2, 0.5],
            ['examples/fire3.jpg', 0.15, 0.5]
        ]
    )
    demo.launch()